from typing import List
import pandas as pd

from interfaces.data_loader import IDataLoader


class PandasDataLoader(IDataLoader):
    """``IDataLoader`` implementation backed by ``pandas.read_excel``.

    Every read uses the same options: all cells are read as strings
    (``dtype=str``) and pandas' NA-marker parsing is disabled
    (``keep_default_na=False``, ``na_values=[]``), so textual values such as
    ``"NA"`` or ``"null"`` are not converted to NaN.
    """

    # Column whose position marks where the data columns of a sheet begin;
    # everything to its right is treated as data by ``load_sheet``.
    ROW_NUMBER_COLUMN = "行号"
    # Column that ``load_sheet`` always places in front of the data columns.
    SUBJECT_ID_COLUMN = "受试者编号"

    def __init__(self, default_source_data: str):
        """Create a loader.

        Args:
            default_source_data: Path (or other source accepted by
                ``pandas.read_excel``) of the workbook used by
                ``load_sheet`` and ``load_specific``.
        """
        self.default_source_data = default_source_data

    @staticmethod
    def _read(source: str, sheet_name=0) -> pd.DataFrame:
        """Read one sheet of *source* with the loader's shared string-preserving options."""
        return pd.read_excel(
            source,
            sheet_name=sheet_name,
            dtype=str,
            keep_default_na=False,
            na_values=[],
        )

    def load_excel(self, path: str) -> pd.DataFrame:
        """Load the first sheet of the workbook at *path*.

        Args:
            path: Workbook to read; ``default_source_data`` is not used.

        Returns:
            The sheet as a DataFrame of strings.
        """
        return self._read(path)

    def load_sheet(self, sheet_name: str) -> pd.DataFrame:
        """Load *sheet_name* from the default workbook, trimmed to its data columns.

        If the sheet has a ``行号`` column, the result consists of the
        ``受试者编号`` column followed by every column to the right of
        ``行号`` (in sheet order). The subject-ID column appears only once,
        even when it also lies to the right of ``行号``. If there is no
        ``行号`` column, the sheet is returned unchanged.

        Args:
            sheet_name: Name of the sheet to read.

        Returns:
            The (possibly trimmed) sheet as a DataFrame of strings.

        Raises:
            KeyError: If ``行号`` is present but ``受试者编号`` is not.
        """
        df = self._read(self.default_source_data, sheet_name)
        column_labels = df.columns.tolist()
        if self.ROW_NUMBER_COLUMN not in column_labels:
            return df

        first_data_pos = column_labels.index(self.ROW_NUMBER_COLUMN) + 1
        # Drop the subject-ID column from the tail so it is not selected twice
        # (a duplicated label would make df[SUBJECT_ID_COLUMN] a DataFrame).
        data_columns = [
            label
            for label in column_labels[first_data_pos:]
            if label != self.SUBJECT_ID_COLUMN
        ]
        return df[[self.SUBJECT_ID_COLUMN] + data_columns]

    def load_specific(self, sheet_name: str, columns: List[str]) -> pd.DataFrame:
        """Load only the requested *columns* of *sheet_name* from the default workbook.

        Args:
            sheet_name: Name of the sheet to read.
            columns: Column labels to keep; the result follows this order.

        Returns:
            A DataFrame of strings restricted to *columns*.

        Raises:
            KeyError: If any requested column is missing from the sheet.
        """
        # The whole sheet is read and then subset (rather than passing
        # ``usecols``) so the column order follows *columns* and missing
        # labels raise KeyError, as before.
        return self._read(self.default_source_data, sheet_name)[columns]